SR Inference (Inline)¶

Run DEM-conditioned super-resolution inference with fixed default settings, including chip-level and mosaic-level diagnostics against bilinear baseline.

0) Imports and Notebook Root¶

In [1]:
# Standard library + numerical stack for inference workflow.
import os
import math
import json
from pathlib import Path

# Array/data/geo stack.
import numpy as np
import pandas as pd
import rasterio
import tensorflow as tf
import matplotlib.pyplot as plt

# Shared diagnostics and plotting used by training notebooks.
import t02.results as results

# Anchor the working directory at the project root so the relative input and
# output paths used by later cells resolve regardless of kernel launch dir.
PROJECT_ROOT = Path("/workspace").resolve()
if PROJECT_ROOT.exists():
    # NOTE(review): if /workspace is absent this silently keeps the launch
    # directory; the existence asserts in the parameters cell will then fail.
    os.chdir(PROJECT_ROOT)

print(f"cwd set to {Path.cwd()}")
2026-02-20 19:36:10.969625: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-20 19:36:10.969674: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-20 19:36:10.970516: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
cwd set to /workspace

1) Parameters (Default Behavior Preserved)¶

In [2]:
# Set fixed notebook parameters (no env overrides).
dem_fp = Path("_inputs/RSSHydro/dudelange/002/DEM.tif")
depth_lores_fp = Path("_inputs/RSSHydro/dudelange/032/ResultA.tif")
depth_hires_valid_fp = Path("_inputs/RSSHydro/dudelange/002/ResultA.tif")
model_fp = Path("train_outputs/4690176_0_1770580046_train_base_16/train_run/model_infer.keras")

# Keep default output behavior.
write_inference_tiff = True

# Keep default preprocessing and postprocessing behavior.
# PRE: resample HR depth + DEM onto the model grid before tiling.
# POST: resample the SR mosaic back onto the raw HR grid for evaluation.
PRE_RESAMPLE_METHOD = "bilinear"
POST_RESAMPLE_METHOD = "bilinear"
# Method names accepted by tf.image.resize; used only for validation below.
ALLOWED_RESAMPLE_METHODS = (
    "nearest",
    "bilinear",
    "bicubic",
    "area",
    "lanczos3",
    "lanczos5",
)

# Use feathered inference only (default overlap derived from existing notebook behavior).
# Overlap is in LR pixels; the HR-pixel overlap is FEATHER_OVERLAP_LR * SCALE.
FEATHER_OVERLAP_LR = 4

# Keep dry/wet diagnostics threshold and low-depth masking defaults.
DRY_DEPTH_THRESH_M = float(results.DEFAULT_DRY_DEPTH_THRESH_M)
APPLY_LOW_DEPTH_MASK = True
LOW_DEPTH_MASK_M = float(DRY_DEPTH_THRESH_M)

# Validate static inputs early so failures happen before any heavy work.
assert dem_fp.exists(), f"DEM file not found: {dem_fp}"
assert depth_lores_fp.exists(), f"Lo-res depth file not found: {depth_lores_fp}"
assert depth_hires_valid_fp.exists(), f"Hi-res validation depth file not found: {depth_hires_valid_fp}"
assert model_fp.exists(), f"Model file not found: {model_fp}"
assert PRE_RESAMPLE_METHOD in ALLOWED_RESAMPLE_METHODS
assert POST_RESAMPLE_METHOD in ALLOWED_RESAMPLE_METHODS
assert FEATHER_OVERLAP_LR >= 0

# Echo the resolved configuration for the rendered notebook record.
print(f"DEM: {dem_fp}")
print(f"LR depth: {depth_lores_fp}")
print(f"HR valid depth: {depth_hires_valid_fp}")
print(f"Model: {model_fp}")
print(f"write_inference_tiff={int(write_inference_tiff)}")
print(f"PRE_RESAMPLE_METHOD={PRE_RESAMPLE_METHOD}, POST_RESAMPLE_METHOD={POST_RESAMPLE_METHOD}")
print(f"FEATHER_OVERLAP_LR={FEATHER_OVERLAP_LR}")
print(f"APPLY_LOW_DEPTH_MASK={int(APPLY_LOW_DEPTH_MASK)}, LOW_DEPTH_MASK_M={LOW_DEPTH_MASK_M:.6f}")
DEM: _inputs/RSSHydro/dudelange/002/DEM.tif
LR depth: _inputs/RSSHydro/dudelange/032/ResultA.tif
HR valid depth: _inputs/RSSHydro/dudelange/002/ResultA.tif
Model: train_outputs/4690176_0_1770580046_train_base_16/train_run/model_infer.keras
write_inference_tiff=1
PRE_RESAMPLE_METHOD=bilinear, POST_RESAMPLE_METHOD=bilinear
FEATHER_OVERLAP_LR=4
APPLY_LOW_DEPTH_MASK=1, LOW_DEPTH_MASK_M=0.010000

2) Load Weights and Training Parameters¶

In [3]:
# Load the exported inference model with custom-activation fallback.
try:
    model = tf.keras.models.load_model(model_fp, compile=False)
except Exception as exc:
    # Retry only for the known missing custom activation; any other load
    # failure is re-raised unchanged.
    if "bounded_sigmoid" not in str(exc):
        raise
    from t02.train_entry.train_base import bounded_sigmoid
    model = tf.keras.models.load_model(
        model_fp,
        compile=False,
        # Register under both the plain name and the Keras-namespaced name so
        # either serialized identifier resolves.
        custom_objects={
            "bounded_sigmoid": bounded_sigmoid,
            "t02>bounded_sigmoid": bounded_sigmoid,
        },
    )

# Freeze weights; this notebook only runs forward passes.
model.trainable = False
print(f"Loaded model: {model_fp}")
INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA RTX 4000 Ada Generation Laptop GPU, compute capability 8.9
Loaded model: train_outputs/4690176_0_1770580046_train_base_16/train_run/model_infer.keras
In [4]:
# Load train_config.json co-located with model weights.
train_config_fp = model_fp.parent / "train_config.json"
if not train_config_fp.exists():
    raise FileNotFoundError(f"train_config.json not found at {train_config_fp}")

train_cfg = json.loads(train_config_fp.read_text())
print(f"Loaded train config from {train_config_fp}")

# Validate required config keys.
required_keys = {"upscale", "input_shape", "max_depth"}
missing_keys = sorted(required_keys.difference(train_cfg.keys()))
if missing_keys:
    raise KeyError(f"train_config missing required keys: {missing_keys}")

# Resolve geometry from train config.
SCALE = int(train_cfg["upscale"])
input_shape = train_cfg["input_shape"]
if not isinstance(input_shape, (list, tuple)) or len(input_shape) < 2:
    raise AssertionError(f"input_shape must have at least [h, w, ...]; got {input_shape}")

# Tiled inference assumes square LR tiles; reject rectangular input shapes.
LR_TILE = int(input_shape[0])
lr_w = int(input_shape[1])
if LR_TILE != lr_w:
    raise AssertionError(f"input_shape must be square for tiled inference; got {input_shape}")
HR_TILE = LR_TILE * SCALE

# Resolve preprocessing parameters from train config.
# DEPTH_LOG_DENOM normalizes log1p(depth) so MAX_DEPTH maps to 1.0.
MAX_DEPTH = float(train_cfg["max_depth"])
DEPTH_LOG_DENOM = float(np.log1p(MAX_DEPTH))

# Percentile clip precedence: explicit dem_pct_clip wins over dem_stats.p_clip.
dem_stats_cfg = train_cfg.get("dem_stats")
if train_cfg.get("dem_pct_clip") is not None:
    DEM_PCT_CLIP = float(train_cfg["dem_pct_clip"])
elif isinstance(dem_stats_cfg, dict) and dem_stats_cfg.get("p_clip") is not None:
    DEM_PCT_CLIP = float(dem_stats_cfg["p_clip"])
else:
    raise KeyError("train_config must include dem_pct_clip or dem_stats.p_clip")

# Reuse training DEM stats when available; otherwise compute from inference raster.
# DEM_REF_STATS stays None unless all three stats are present together.
DEM_REF_STATS = None
if isinstance(dem_stats_cfg, dict):
    required_dem_stats = {"p_clip", "dem_min", "dem_max"}
    if required_dem_stats.issubset(dem_stats_cfg.keys()):
        DEM_REF_STATS = {k: float(dem_stats_cfg[k]) for k in sorted(required_dem_stats)}

# Resolve labels for downstream charts (first non-empty label key wins).
MODEL_NAME = str(train_cfg.get("model_name") or "Model").strip()
LOSS_LABEL = str(
    train_cfg.get("loss_plot_label")
    or train_cfg.get("loss_label")
    or train_cfg.get("loss_name")
    or "Loss"
).strip()
MODEL_SERIES_LABEL = f"SR ({MODEL_NAME})"

# Guardrails for config consistency.
assert SCALE > 0, f"SCALE must be > 0; got {SCALE}"
assert LR_TILE > 0, f"LR_TILE must be > 0; got {LR_TILE}"
assert HR_TILE == LR_TILE * SCALE
assert MAX_DEPTH > 0, f"MAX_DEPTH must be > 0; got {MAX_DEPTH}"
assert 0 < DEM_PCT_CLIP <= 100
assert 0 <= FEATHER_OVERLAP_LR < LR_TILE, (
    f"FEATHER_OVERLAP_LR must be in [0, {LR_TILE}); got {FEATHER_OVERLAP_LR}"
)
assert LOW_DEPTH_MASK_M >= 0.0
assert LOW_DEPTH_MASK_M <= MAX_DEPTH

print("Resolved from train_config.json:")
print(f"  SCALE={SCALE}, LR_TILE={LR_TILE}, HR_TILE={HR_TILE}")
print(f"  MAX_DEPTH={MAX_DEPTH:.6f}, DEM_PCT_CLIP={DEM_PCT_CLIP:.6f}")
print(f"  MODEL_SERIES_LABEL={MODEL_SERIES_LABEL}")
print(f"  LOSS_LABEL={LOSS_LABEL}")
if DEM_REF_STATS is not None:
    print(f"  DEM_REF_STATS={DEM_REF_STATS}")
else:
    print("  DEM_REF_STATS=None (will compute from inference DEM)")
Loaded train config from train_outputs/4690176_0_1770580046_train_base_16/train_run/train_config.json
Resolved from train_config.json:
  SCALE=16, LR_TILE=32, HR_TILE=512
  MAX_DEPTH=5.000000, DEM_PCT_CLIP=95.000000
  MODEL_SERIES_LABEL=SR (ResUNet_DEM_16)
  LOSS_LABEL=MAE
  DEM_REF_STATS=None (will compute from inference DEM)

3) Load and Validate Raw Rasters¶

In [5]:
# Open rasters and enforce grid compatibility checks before preprocessing.
with (
    rasterio.open(dem_fp) as dem_src,
    rasterio.open(depth_lores_fp) as lr_src,
    rasterio.open(depth_hires_valid_fp) as hr_src,
):
    # Ensure CRS consistency for all sources.
    if dem_src.crs != hr_src.crs or lr_src.crs != hr_src.crs:
        raise AssertionError("CRS mismatch between rasters")

    # Ensure bounds consistency for all sources (tolerance is in CRS units).
    dem_bounds = tuple(float(v) for v in dem_src.bounds)
    lr_bounds = tuple(float(v) for v in lr_src.bounds)
    hr_bounds = tuple(float(v) for v in hr_src.bounds)
    if not all(np.isclose(a, b, atol=1e-6) for a, b in zip(dem_bounds, hr_bounds)):
        raise AssertionError(f"DEM bounds {dem_bounds} != HR bounds {hr_bounds}")
    if not all(np.isclose(a, b, atol=1e-6) for a, b in zip(lr_bounds, hr_bounds)):
        raise AssertionError(f"LR bounds {lr_bounds} != HR bounds {hr_bounds}")

    # Ensure square pixels in each raster.
    if not np.isclose(abs(dem_src.res[0]), abs(dem_src.res[1])):
        raise AssertionError(f"DEM pixels are not square: res={dem_src.res}")
    if not np.isclose(abs(lr_src.res[0]), abs(lr_src.res[1])):
        raise AssertionError(f"LR pixels are not square: res={lr_src.res}")
    if not np.isclose(abs(hr_src.res[0]), abs(hr_src.res[1])):
        raise AssertionError(f"HR pixels are not square: res={hr_src.res}")

    # Read masked arrays first so nodata issues are explicit.
    dem_read = dem_src.read(1, masked=True)
    lr_read = lr_src.read(1, masked=True)
    hr_read = hr_src.read(1, masked=True)

    # Reject nodata/masked values and NaNs; this notebook assumes dense grids.
    if np.ma.isMaskedArray(dem_read) and np.any(dem_read.mask):
        raise AssertionError("DEM contains nodata/masked values")
    if np.ma.isMaskedArray(lr_read) and np.any(lr_read.mask):
        raise AssertionError("LR depth contains nodata/masked values")
    if np.ma.isMaskedArray(hr_read) and np.any(hr_read.mask):
        raise AssertionError("HR depth contains nodata/masked values")

    dem_raw = np.asarray(dem_read, dtype=np.float32)
    lr_raw = np.asarray(lr_read, dtype=np.float32)
    hr_raw = np.asarray(hr_read, dtype=np.float32)

    if np.isnan(dem_raw).any() or np.isnan(lr_raw).any() or np.isnan(hr_raw).any():
        raise AssertionError("Input rasters contain NaNs")

    # Keep original profiles for export and geometry reporting.
    dem_profile = dem_src.profile.copy()
    lr_profile = lr_src.profile.copy()
    hr_profile = hr_src.profile.copy()

    # Absolute resolutions (x, y) in CRS units; sign dropped intentionally.
    dem_res = (abs(float(dem_src.res[0])), abs(float(dem_src.res[1])))
    lr_res_raw = (abs(float(lr_src.res[0])), abs(float(lr_src.res[1])))
    hr_res = (abs(float(hr_src.res[0])), abs(float(hr_src.res[1])))

# Apply basic value-range checks used by the original notebook.
# NOTE(review): 15 m and 5000 m are dataset-level sanity bounds, deliberately
# independent of MAX_DEPTH (the model's normalization cap).
if np.min(lr_raw) < 0 or np.min(hr_raw) < 0:
    raise AssertionError("Depth rasters contain negative values")
if np.max(lr_raw) > 15 or np.max(hr_raw) > 15:
    raise AssertionError("Depth rasters exceed 15m max depth")
if np.min(dem_raw) < 0 or np.max(dem_raw) > 5000:
    raise AssertionError("DEM raster outside expected range [0, 5000]")

# Report raw geometry and model-scale compatibility.
raw_shape_scale = (hr_raw.shape[0] / lr_raw.shape[0], hr_raw.shape[1] / lr_raw.shape[1])
raw_res_scale = (lr_res_raw[0] / hr_res[0], lr_res_raw[1] / hr_res[1])
print(f"Raw shapes HR/LR: {hr_raw.shape} / {lr_raw.shape}")
print(
    "Raw resolutions DEM/LR/HR: "
    f"DEM=({dem_res[0]:.6g}, {dem_res[1]:.6g}) "
    f"LR=({lr_res_raw[0]:.6g}, {lr_res_raw[1]:.6g}) "
    f"HR=({hr_res[0]:.6g}, {hr_res[1]:.6g})"
)
print(
    f"Raw HR:LR scale from shape (h,w)=({raw_shape_scale[0]:.2f}, {raw_shape_scale[1]:.2f}), "
    f"from resolution (x,y)=({raw_res_scale[0]:.2f}, {raw_res_scale[1]:.2f})"
)

# Warn (not fail) when the native LR/HR ratio differs from the model SCALE;
# the preprocessing cell resamples HR/DEM to compensate.
if not np.isclose(raw_res_scale[0], SCALE) or not np.isclose(raw_res_scale[1], SCALE):
    print(
        f"WARNING: LR/HR resolution ratio (x,y)=({raw_res_scale[0]:.2f}, {raw_res_scale[1]:.2f}) != SCALE={SCALE}."
    )
    print("         HR depth and DEM will be resampled to match model scale.")
Raw shapes HR/LR: (2030, 2090) / (203, 209)
Raw resolutions DEM/LR/HR: DEM=(3, 3) LR=(30, 30) HR=(3, 3)
Raw HR:LR scale from shape (h,w)=(10.00, 10.00), from resolution (x,y)=(10.00, 10.00)
WARNING: LR/HR resolution ratio (x,y)=(10.00, 10.00) != SCALE=16.
         HR depth and DEM will be resampled to match model scale.

4) Plot Raw Rasters¶

In [6]:
# Plot histograms + rasters for raw LR depth, raw HR depth, and raw DEM.
# Each row: (title, array, colormap, apply dry mask?, dry threshold, res (x, y)).
plot_specs_raw = [
    ("LR depth (raw)", lr_raw, "viridis", True, DRY_DEPTH_THRESH_M, lr_res_raw),
    ("HR depth (raw)", hr_raw, "viridis", True, DRY_DEPTH_THRESH_M, hr_res),
    ("DEM (raw)", dem_raw, "terrain", False, None, dem_res),
]

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))

# One axes row per spec: histogram on the left, raster image on the right.
for spec, (ax_hist, ax_raster) in zip(plot_specs_raw, axes):
    title, arr, cmap, use_dry_mask, dry_thresh, res_xy = spec
    data = np.asarray(arr, dtype=np.float32)
    finite_vals = data[np.isfinite(data)]

    ax_hist.hist(finite_vals, bins=60, color="steelblue", alpha=0.9)
    if use_dry_mask:
        # Mark the dry/wet threshold on depth histograms.
        ax_hist.axvline(dry_thresh, color="red", linestyle="--", linewidth=1.5)
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel("Value")
    ax_hist.set_ylabel("Count")
    ax_hist.grid(color="lightgrey", linestyle="-", linewidth=0.7)

    # Summary stats box in the top-right corner of the histogram panel.
    stats_lines = [
        f"shape: {data.shape}",
        f"res(x,y): ({res_xy[0]:.6g}, {res_xy[1]:.6g})",
        f"min: {finite_vals.min():.3f}",
        f"max: {finite_vals.max():.3f}",
        f"mean: {finite_vals.mean():.3f}",
        f"std: {finite_vals.std():.3f}",
    ]
    ax_hist.text(
        0.98,
        0.95,
        "\n".join(stats_lines),
        transform=ax_hist.transAxes,
        fontsize=9,
        verticalalignment="top",
        horizontalalignment="right",
    )

    # Hide sub-threshold (dry) pixels on depth rasters only.
    if use_dry_mask:
        raster_img = np.ma.masked_where(data < dry_thresh, data)
    else:
        raster_img = data
    im = ax_raster.imshow(raster_img, cmap=cmap)
    ax_raster.set_title(f"{title} raster")
    ax_raster.set_axis_off()
    fig.colorbar(im, ax=ax_raster, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
No description has been provided for this image

5) Pre-process Rasters for Model Input¶

In [15]:
# Keep raw rasters for later evaluation and create model-space normalized tensors.
# The *_raw_m aliases preserve the raw (meters) arrays; later cells rebind the
# short names (e.g. hr_raw) to normalized arrays, so keep these for metrics.
lr_depth_raw_m = np.asarray(lr_raw, dtype=np.float32)
hr_depth_raw_m = np.asarray(hr_raw, dtype=np.float32)
dem_raw_m = np.asarray(dem_raw, dtype=np.float32)

raw_hr_shape = tuple(int(v) for v in hr_depth_raw_m.shape)
raw_lr_shape = tuple(int(v) for v in lr_depth_raw_m.shape)

# Build model-space HR shape from raw LR and fixed SCALE.
# LR stays on its native grid; HR/DEM are resampled to LR_shape * SCALE.
target_hr_h = raw_lr_shape[0] * SCALE
target_hr_w = raw_lr_shape[1] * SCALE
print("Preprocessing strategy: keep LR raw; resample HR/DEM to model SCALE.")
print(f"  pre-resample method (HR/DEM->model): {PRE_RESAMPLE_METHOD}")
print(f"  post-resample method (SR->raw HR): {POST_RESAMPLE_METHOD}")
print(f"  raw LR shape/res: {raw_lr_shape} / ({lr_res_raw[0]:.6g}, {lr_res_raw[1]:.6g})")
print(f"  raw HR shape/res: {raw_hr_shape} / ({hr_res[0]:.6g}, {hr_res[1]:.6g})")
print(f"  model HR target from LR*SCALE: {(target_hr_h, target_hr_w)} (SCALE={SCALE})")
Preprocessing strategy: keep LR raw; resample HR/DEM to model SCALE.
  pre-resample method (HR/DEM->model): bilinear
  post-resample method (SR->raw HR): bilinear
  raw LR shape/res: (203, 209) / (30, 30)
  raw HR shape/res: (3248, 3344) / (3, 3)
  model HR target from LR*SCALE: (3248, 3344) (SCALE=16)
In [16]:
# Resample HR depth and DEM to model-space HR grid.
hr_tensor = tf.convert_to_tensor(hr_depth_raw_m[None, ..., None], dtype=tf.float32)
dem_tensor = tf.convert_to_tensor(dem_raw_m[None, ..., None], dtype=tf.float32)
# Antialiasing only matters for non-nearest kernels when downsampling.
use_antialias_pre = PRE_RESAMPLE_METHOD != "nearest"

hr_model_depth_m = tf.image.resize(
    hr_tensor,
    size=(target_hr_h, target_hr_w),
    method=PRE_RESAMPLE_METHOD,
    antialias=use_antialias_pre,
)[0, ..., 0].numpy().astype(np.float32, copy=False)

dem_model_raw_m = tf.image.resize(
    dem_tensor,
    size=(target_hr_h, target_hr_w),
    method=PRE_RESAMPLE_METHOD,
    antialias=use_antialias_pre,
)[0, ..., 0].numpy().astype(np.float32, copy=False)

# Normalize depth with clip + log1p (same transform as training).
# Report how much of each raster is clipped at MAX_DEPTH before normalizing.
sat_lr = float(np.mean(lr_depth_raw_m >= MAX_DEPTH))
sat_hr_raw = float(np.mean(hr_depth_raw_m >= MAX_DEPTH))
sat_hr_model = float(np.mean(hr_model_depth_m >= MAX_DEPTH))
print(
    "Depth saturation @ MAX_DEPTH: "
    f"LR(raw)={sat_lr:.2%}, HR(raw)={sat_hr_raw:.2%}, HR(model-resampled)={sat_hr_model:.2%}"
)

lr_norm = np.log1p(np.clip(lr_depth_raw_m, 0.0, MAX_DEPTH)) / DEPTH_LOG_DENOM
hr_norm_raw = np.log1p(np.clip(hr_depth_raw_m, 0.0, MAX_DEPTH)) / DEPTH_LOG_DENOM
# NOTE(review): this REBINDS `hr_raw` — from here on it holds the normalized,
# model-grid HR depth, not the raw meters array plotted earlier (that survives
# as `hr_depth_raw_m`). Intentional per later cells, but a hidden-state hazard
# given the non-sequential execution counts; consider renaming (e.g. hr_norm_model).
hr_raw = np.log1p(np.clip(hr_model_depth_m, 0.0, MAX_DEPTH)) / DEPTH_LOG_DENOM

lr_norm = np.clip(lr_norm, 0.0, 1.0).astype(np.float32, copy=False)
hr_norm_raw = np.clip(hr_norm_raw, 0.0, 1.0).astype(np.float32, copy=False)
hr_raw = np.clip(hr_raw, 0.0, 1.0).astype(np.float32, copy=False)

# Normalize DEM with either train-config stats or inference DEM-derived stats.
if DEM_REF_STATS is None:
    # Derive clip/min/max from the inference DEM itself (model-space grid).
    dem_for_stats = np.clip(dem_model_raw_m, 0.0, None).astype(np.float32, copy=False)
    p_clip = float(np.nanpercentile(dem_for_stats, DEM_PCT_CLIP))
    dem_for_stats = np.clip(dem_for_stats, 0.0, p_clip)
    dem_min = float(np.nanmin(dem_for_stats))
    dem_max = float(np.nanmax(dem_for_stats))
else:
    required_dem_keys = {"p_clip", "dem_min", "dem_max"}
    missing_dem_keys = required_dem_keys.difference(DEM_REF_STATS.keys())
    if missing_dem_keys:
        raise AssertionError(f"DEM_REF_STATS missing keys: {sorted(missing_dem_keys)}")
    p_clip = float(DEM_REF_STATS["p_clip"])
    dem_min = float(DEM_REF_STATS["dem_min"])
    dem_max = float(DEM_REF_STATS["dem_max"])

dem_range = dem_max - dem_min
if not np.isfinite(p_clip) or not np.isfinite(dem_min) or not np.isfinite(dem_max):
    raise AssertionError("DEM normalization stats must be finite")
if dem_range <= 0:
    raise AssertionError(f"DEM range must be > 0; got min={dem_min}, max={dem_max}")

# Min-max normalize the percentile-clipped DEM into [0, 1].
dem_clipped = np.clip(np.asarray(dem_model_raw_m, dtype=np.float32), 0.0, p_clip)
dem_norm = (dem_clipped - dem_min) / dem_range
dem_norm = np.clip(dem_norm, 0.0, 1.0).astype(np.float32, copy=False)
dem_stats = {"p_clip": p_clip, "dem_min": dem_min, "dem_max": dem_max}

if DEM_REF_STATS is None:
    print(f"DEM stats computed from model-space DEM: {dem_stats}")
else:
    print(f"DEM stats reused from train_config: {dem_stats}")

# Keep LR on raw grid and derive model-space HR resolution.
lr_res_model = lr_res_raw
hr_res_model = (lr_res_model[0] / SCALE, lr_res_model[1] / SCALE)

# `hr_raw` here is the normalized model-grid HR depth (rebound above).
assert dem_norm.shape == hr_raw.shape
assert lr_norm.shape == raw_lr_shape
assert hr_norm_raw.shape == raw_hr_shape

# Report geometry before and after model-space alignment.
pre_shape_scale = (raw_hr_shape[0] / raw_lr_shape[0], raw_hr_shape[1] / raw_lr_shape[1])
pre_res_scale = (lr_res_raw[0] / hr_res[0], lr_res_raw[1] / hr_res[1])
post_shape_scale = (hr_raw.shape[0] / lr_norm.shape[0], hr_raw.shape[1] / lr_norm.shape[1])
post_res_scale = (lr_res_model[0] / hr_res_model[0], lr_res_model[1] / hr_res_model[1])

print("Pre-resample geometry (raw inputs):")
print(f"  shapes HR/LR: {raw_hr_shape} / {raw_lr_shape}")
print(f"  resolutions HR/LR: ({hr_res[0]:.6g}, {hr_res[1]:.6g}) / ({lr_res_raw[0]:.6g}, {lr_res_raw[1]:.6g})")
print(f"  scale from shape (h,w): ({pre_shape_scale[0]:.2f}, {pre_shape_scale[1]:.2f})")
print(f"  scale from resolution (x,y): ({pre_res_scale[0]:.2f}, {pre_res_scale[1]:.2f})")

print("Post-resample geometry (model input):")
print(f"  shapes HR/LR: {hr_raw.shape} / {lr_norm.shape}")
print(f"  resolutions HR/LR: ({hr_res_model[0]:.6g}, {hr_res_model[1]:.6g}) / ({lr_res_model[0]:.6g}, {lr_res_model[1]:.6g})")
print(f"  scale from shape (h,w): ({post_shape_scale[0]:.2f}, {post_shape_scale[1]:.2f})")
print(f"  scale from resolution (x,y): ({post_res_scale[0]:.2f}, {post_res_scale[1]:.2f})")

# Set model-valid extent before tile padding.
crop_h, crop_w = hr_raw.shape
print(f"Model-space valid extent (before tiling pad): {(crop_h, crop_w)}")

# Pad HR/DEM/LR arrays to exact tile multiples for windowed inference.
# Zero padding is safe because all arrays are normalized to [0, 1].
pad_h = (int(math.ceil(crop_h / HR_TILE)) * HR_TILE) - crop_h
pad_w = (int(math.ceil(crop_w / HR_TILE)) * HR_TILE) - crop_w

hr_pad = np.pad(hr_raw, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0.0)
dem_pad = np.pad(dem_norm, ((0, pad_h), (0, pad_w)), mode="constant", constant_values=0.0)

# Pad LR so the padded grids stay in exact SCALE correspondence.
lr_pad_h_target = hr_pad.shape[0] // SCALE
lr_pad_w_target = hr_pad.shape[1] // SCALE
lr_pad_h_add = lr_pad_h_target - lr_norm.shape[0]
lr_pad_w_add = lr_pad_w_target - lr_norm.shape[1]
assert lr_pad_h_add >= 0 and lr_pad_w_add >= 0

lr_pad = np.pad(
    lr_norm,
    ((0, lr_pad_h_add), (0, lr_pad_w_add)),
    mode="constant",
    constant_values=0.0,
)

pad_shape_scale = (hr_pad.shape[0] / lr_pad.shape[0], hr_pad.shape[1] / lr_pad.shape[1])
print(f"Padded shapes HR/LR: {hr_pad.shape} / {lr_pad.shape}")
print(
    f"Padded HR:LR scale from shape (h,w)=({pad_shape_scale[0]:.2f}, {pad_shape_scale[1]:.2f}) (expected {SCALE})"
)
Depth saturation @ MAX_DEPTH: LR(raw)=0.00%, HR(raw)=0.00%, HR(model-resampled)=0.00%
DEM stats computed from model-space DEM: {'p_clip': 409.72169342041013, 'dem_min': 226.19998168945312, 'dem_max': 409.7216796875}
Pre-resample geometry (raw inputs):
  shapes HR/LR: (3248, 3344) / (203, 209)
  resolutions HR/LR: (3, 3) / (30, 30)
  scale from shape (h,w): (16.00, 16.00)
  scale from resolution (x,y): (10.00, 10.00)
Post-resample geometry (model input):
  shapes HR/LR: (3248, 3344) / (203, 209)
  resolutions HR/LR: (1.875, 1.875) / (30, 30)
  scale from shape (h,w): (16.00, 16.00)
  scale from resolution (x,y): (16.00, 16.00)
Model-space valid extent (before tiling pad): (3248, 3344)
Padded shapes HR/LR: (3584, 3584) / (224, 224)
Padded HR:LR scale from shape (h,w)=(16.00, 16.00) (expected 16)

6) Plot Pre-processed Rasters¶

In [17]:
# Plot histograms + rasters after normalization and model-space alignment.
# NOTE(review): this loop duplicates the raw-raster plotting cell almost
# verbatim; both could share a small plotting helper.
# Dry threshold mapped into normalized depth space (same log1p transform).
dry_thresh_norm = float(np.log1p(np.clip(DRY_DEPTH_THRESH_M, 0.0, MAX_DEPTH)) / DEPTH_LOG_DENOM)
plot_specs_norm = [
    ("LR depth (normalized, raw grid)", lr_norm, "viridis", True, dry_thresh_norm, lr_res_model),
    ("HR depth (normalized, model grid)", hr_raw, "viridis", True, dry_thresh_norm, hr_res_model),
    ("DEM (normalized, model grid)", dem_norm, "terrain", False, None, hr_res_model),
]

fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 12))

for row_idx, (title, arr, cmap, use_dry_mask, dry_thresh, res_xy) in enumerate(plot_specs_norm):
    arr = np.asarray(arr, dtype=np.float32)
    vals = arr[np.isfinite(arr)]

    ax_hist = axes[row_idx, 0]
    ax_raster = axes[row_idx, 1]

    ax_hist.hist(vals, bins=60, color="steelblue", alpha=0.9)
    if use_dry_mask:
        # Mark the normalized dry/wet threshold on depth histograms.
        ax_hist.axvline(dry_thresh, color="red", linestyle="--", linewidth=1.5)
    ax_hist.set_title(f"{title} histogram")
    ax_hist.set_xlabel("Value")
    ax_hist.set_ylabel("Count")
    ax_hist.grid(color="lightgrey", linestyle="-", linewidth=0.7)

    # Summary stats box in the top-right corner of the histogram panel.
    ax_hist.text(
        0.98,
        0.95,
        (
            f"shape: {arr.shape}\n"
            f"res(x,y): ({res_xy[0]:.6g}, {res_xy[1]:.6g})\n"
            f"min: {vals.min():.3f}\n"
            f"max: {vals.max():.3f}\n"
            f"mean: {vals.mean():.3f}\n"
            f"std: {vals.std():.3f}"
        ),
        transform=ax_hist.transAxes,
        fontsize=9,
        verticalalignment="top",
        horizontalalignment="right",
    )

    # Hide sub-threshold (dry) pixels on depth rasters only.
    raster_arr = np.ma.masked_where(arr < dry_thresh, arr) if use_dry_mask else arr
    im = ax_raster.imshow(raster_arr, cmap=cmap)
    ax_raster.set_title(f"{title} raster")
    ax_raster.set_axis_off()
    fig.colorbar(im, ax=ax_raster, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()
No description has been provided for this image

7) Windowed Inference (Feathered)¶

In [18]:
# Run non-overlap chip inference first to build cache and chip diagnostics arrays.
hr_pad_h, hr_pad_w = hr_pad.shape
# Cache keyed by HR-pixel window origin (y0, x0); shared with the feather pass.
tile_pred_cache = {}

nonoverlap_y_starts = list(range(0, hr_pad_h, HR_TILE))
nonoverlap_x_starts = list(range(0, hr_pad_w, HR_TILE))
sr_pad_nonoverlap = np.zeros_like(hr_pad, dtype=np.float32)

print(
    f"Running non-overlap per-chip inference on {len(nonoverlap_y_starts) * len(nonoverlap_x_starts)} chips..."
)

for y0 in nonoverlap_y_starts:
    for x0 in nonoverlap_x_starts:
        key = (int(y0), int(x0))

        if key not in tile_pred_cache:
            # Map HR origin to LR origin on fixed SCALE grid.
            lr_y0 = y0 // SCALE
            lr_x0 = x0 // SCALE

            # Slice aligned LR-depth and HR-DEM tiles.
            lr_tile = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
            dem_tile = dem_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]

            # Defensive checks: exact tile shape and normalized value range.
            if lr_tile.shape != (LR_TILE, LR_TILE):
                raise AssertionError(f"LR tile shape mismatch: {lr_tile.shape}")
            if dem_tile.shape != (HR_TILE, HR_TILE):
                raise AssertionError(f"DEM tile shape mismatch: {dem_tile.shape}")
            if lr_tile.min() < 0.0 or lr_tile.max() > 1.0:
                raise AssertionError("lr_tile values must be in [0,1]")
            if dem_tile.min() < 0.0 or dem_tile.max() > 1.0:
                raise AssertionError("dem_tile values must be in [0,1]")

            # Add channel + batch dims expected by model.
            lr_tile_batched = lr_tile[..., None][None, ...].astype(np.float32, copy=False)
            dem_tile_batched = dem_tile[..., None][None, ...].astype(np.float32, copy=False)

            # Run model and normalize output shape/value range.
            pred = model((lr_tile_batched, dem_tile_batched), training=False)
            pred_np = np.asarray(pred, dtype=np.float32)

            # Accept rank-4 (batch of 1) or rank-3 (single sample) outputs.
            if pred_np.ndim == 4:
                if pred_np.shape[0] != 1:
                    raise AssertionError(f"prediction batch dimension must be 1; got {pred_np.shape}")
                pred_np = pred_np[0]
            elif pred_np.ndim != 3:
                raise AssertionError(f"prediction rank must be 3 or 4; got {pred_np.shape}")

            if tuple(pred_np.shape) != (HR_TILE, HR_TILE, 1):
                raise AssertionError(
                    f"prediction sample shape must be {(HR_TILE, HR_TILE, 1)}; got {pred_np.shape}"
                )
            if not np.all(np.isfinite(pred_np)):
                raise AssertionError("prediction must contain only finite values")

            # Cache the clipped 2D prediction for reuse by the feather pass.
            pred_np = np.clip(pred_np, 0.0, 1.0).astype(np.float32, copy=False)[..., 0]
            tile_pred_cache[key] = pred_np

        sr_pad_nonoverlap[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] = tile_pred_cache[key]

# Build valid-chip stacks for diagnostics from fully valid (non-padded) chip grid.
# Only chips fully inside the un-padded extent (crop_h x crop_w) are used.
valid_tiles_h = crop_h // HR_TILE
valid_tiles_w = crop_w // HR_TILE
if valid_tiles_h == 0 or valid_tiles_w == 0:
    raise ValueError(f"No fully valid chips for diagnostics (crop={(crop_h, crop_w)}, HR_TILE={HR_TILE}).")

n_valid = valid_tiles_h * valid_tiles_w
lowres_chips = np.zeros((n_valid, LR_TILE, LR_TILE, 1), dtype=np.float32)
highres_chips = np.zeros((n_valid, HR_TILE, HR_TILE, 1), dtype=np.float32)
preds_chips = np.zeros((n_valid, HR_TILE, HR_TILE, 1), dtype=np.float32)
chip_coords = []

chip_idx = 0
for ty in range(valid_tiles_h):
    y0 = ty * HR_TILE
    for tx in range(valid_tiles_w):
        x0 = tx * HR_TILE

        lr_y0 = y0 // SCALE
        lr_x0 = x0 // SCALE

        lowres_chips[chip_idx, ..., 0] = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
        highres_chips[chip_idx, ..., 0] = hr_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]
        preds_chips[chip_idx, ..., 0] = tile_pred_cache[(y0, x0)]
        chip_coords.append((y0, x0))
        chip_idx += 1

chip_coords = np.asarray(chip_coords, dtype=np.int32)
print(f"Prepared {chip_idx} valid chips ({valid_tiles_h} x {valid_tiles_w}) for diagnostics.")

# Build feather window grid with forced trailing-edge coverage.
overlap_hr = FEATHER_OVERLAP_LR * SCALE
stride_hr = HR_TILE - overlap_hr
if stride_hr <= 0:
    raise AssertionError(
        f"Feather stride must be > 0; got stride_hr={stride_hr} from FEATHER_OVERLAP_LR={FEATHER_OVERLAP_LR}"
    )

# Window origins every stride; append a final window flush with the trailing
# edge so the whole padded mosaic is covered even when the stride does not
# divide the padded size evenly.
y_starts = list(range(0, max(hr_pad_h - HR_TILE + 1, 1), stride_hr))
x_starts = list(range(0, max(hr_pad_w - HR_TILE + 1, 1), stride_hr))
last_y = hr_pad_h - HR_TILE
last_x = hr_pad_w - HR_TILE
if y_starts[-1] != last_y:
    y_starts.append(last_y)
if x_starts[-1] != last_x:
    x_starts.append(last_x)

# Build symmetric 1D feather ramp for separable 2D blending.
feather_1d = np.ones(HR_TILE, dtype=np.float32)
if overlap_hr > 0:
    # Interior ramp points only (endpoints 0 and 1 excluded), so weights stay
    # strictly positive after the clip below.
    ramp = np.linspace(0.0, 1.0, overlap_hr + 2, dtype=np.float32)[1:-1]
    feather_1d[:overlap_hr] = ramp
    feather_1d[-overlap_hr:] = ramp[::-1]
feather_1d = np.clip(feather_1d, 1e-3, 1.0)

accum = np.zeros_like(hr_pad, dtype=np.float32)
weight_sum = np.zeros_like(hr_pad, dtype=np.float32)

print(
    f"Mosaicing with feather windows: {len(y_starts) * len(x_starts)} windows "
    f"(overlap={overlap_hr} px, stride={stride_hr} px)..."
)

for yi, y0 in enumerate(y_starts):
    for xi, x0 in enumerate(x_starts):
        key = (int(y0), int(x0))

        if key not in tile_pred_cache:
            # Map HR origin to LR origin and slice aligned tiles.
            lr_y0 = y0 // SCALE
            lr_x0 = x0 // SCALE
            lr_tile = lr_pad[lr_y0 : lr_y0 + LR_TILE, lr_x0 : lr_x0 + LR_TILE]
            dem_tile = dem_pad[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE]

            if lr_tile.shape != (LR_TILE, LR_TILE):
                raise AssertionError(f"LR tile shape mismatch: {lr_tile.shape}")
            if dem_tile.shape != (HR_TILE, HR_TILE):
                raise AssertionError(f"DEM tile shape mismatch: {dem_tile.shape}")

            lr_tile_batched = lr_tile[..., None][None, ...].astype(np.float32, copy=False)
            dem_tile_batched = dem_tile[..., None][None, ...].astype(np.float32, copy=False)

            pred = model((lr_tile_batched, dem_tile_batched), training=False)
            pred_np = np.asarray(pred, dtype=np.float32)
            # Validate with the same rules as the non-overlap pass (previously
            # this path skipped the batch-dim, rank-3, and finite checks):
            # batch of 1 or bare sample, exact tile shape, finite values.
            if pred_np.ndim == 4:
                if pred_np.shape[0] != 1:
                    raise AssertionError(f"prediction batch dimension must be 1; got {pred_np.shape}")
                pred_np = pred_np[0]
            elif pred_np.ndim != 3:
                raise AssertionError(f"prediction rank must be 3 or 4; got {pred_np.shape}")
            if tuple(pred_np.shape) != (HR_TILE, HR_TILE, 1):
                raise AssertionError(
                    f"prediction sample shape must be {(HR_TILE, HR_TILE, 1)}; got {pred_np.shape}"
                )
            if not np.all(np.isfinite(pred_np)):
                raise AssertionError("prediction must contain only finite values")
            pred_np = np.clip(pred_np, 0.0, 1.0).astype(np.float32, copy=False)[..., 0]
            tile_pred_cache[key] = pred_np

        pred_np = tile_pred_cache[key]

        # Flatten exterior feather edges so the scene boundary is not dimmed.
        wy = feather_1d.copy()
        wx = feather_1d.copy()
        if yi == 0:
            wy[:overlap_hr] = 1.0
        if yi == len(y_starts) - 1:
            wy[-overlap_hr:] = 1.0
        if xi == 0:
            wx[:overlap_hr] = 1.0
        if xi == len(x_starts) - 1:
            wx[-overlap_hr:] = 1.0

        # Separable 2D weight; accumulate weighted predictions and weights.
        weight = np.outer(wy, wx).astype(np.float32, copy=False)
        accum[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] += pred_np * weight
        weight_sum[y0 : y0 + HR_TILE, x0 : x0 + HR_TILE] += weight

# Weighted average; zero-weight pixels (none expected after coverage) stay 0.
sr_pad = np.divide(
    accum,
    np.maximum(weight_sum, 1e-6),
    out=np.zeros_like(accum),
    where=weight_sum > 0,
)
print(f"Cached predictions after feather mosaicing: {len(tile_pred_cache)}")

# Crop back to valid model-space extent and bound normalized range.
sr_model = np.clip(sr_pad[:crop_h, :crop_w], 0.0, 1.0)
# `hr_raw` is the normalized model-grid HR depth (rebound in the preprocessing cell).
hr_model_valid = hr_raw[:crop_h, :crop_w]
print(f"Model-space SR/HR shapes: {sr_model.shape} / {hr_model_valid.shape}")

# Resample model-space SR back to raw HR grid for final evaluation.
print("Post-processing: resample model-space SR back to raw HR grid.")
print(f"  method={POST_RESAMPLE_METHOD}, target shape/res={raw_hr_shape} / ({hr_res[0]:.6g}, {hr_res[1]:.6g})")

sr_tensor = tf.convert_to_tensor(sr_model[None, ..., None], dtype=tf.float32)
use_antialias_post = POST_RESAMPLE_METHOD != "nearest"
sr = tf.image.resize(
    sr_tensor,
    size=raw_hr_shape,
    method=POST_RESAMPLE_METHOD,
    antialias=use_antialias_post,
)[0, ..., 0].numpy().astype(np.float32, copy=False)
sr = np.clip(sr, 0.0, 1.0)

# Keep normalized HR target on original raw HR grid for metrics.
hr_valid = np.clip(hr_norm_raw, 0.0, 1.0).astype(np.float32, copy=False)

# Apply low-depth mask on SR output in raw HR grid (default behavior keeps this enabled).
# The meters threshold is mapped into normalized space with the same log1p
# transform, then sub-threshold SR pixels are zeroed (treated as dry).
low_depth_mask_norm = float(np.log1p(np.clip(LOW_DEPTH_MASK_M, 0.0, MAX_DEPTH)) / DEPTH_LOG_DENOM)
if APPLY_LOW_DEPTH_MASK:
    low_mask = sr < low_depth_mask_norm
    masked_count = int(np.sum(low_mask))
    sr = np.where(low_mask, 0.0, sr).astype(np.float32, copy=False)
    print(
        "Applied low-depth SR mask: "
        f"threshold={LOW_DEPTH_MASK_M:.6f} m ({low_depth_mask_norm:.6f} norm), "
        f"masked={masked_count:,}/{sr.size:,} ({masked_count/sr.size:.2%})"
    )

print(f"Post-processed SR/HR shapes: {sr.shape} / {hr_valid.shape}")
Running non-overlap per-chip inference on 49 chips...
Prepared 36 valid chips (6 x 6) for diagnostics.
Mosaicing with feather windows: 64 windows (overlap=64 px, stride=448 px)...
Cached predictions after feather mosaicing: 109
Model-space SR/HR shapes: (3248, 3344) / (3248, 3344)
Post-processing: resample model-space SR back to raw HR grid.
  method=bilinear, target shape/res=(3248, 3344) / (3, 3)
Applied low-depth SR mask: threshold=0.010000 m (0.005553 norm), masked=9,907,796/10,861,312 (91.22%)
Post-processed SR/HR shapes: (3248, 3344) / (3248, 3344)

8) Per-chip Performance¶

In [19]:
# Evaluate chip-level SR vs bilinear baseline metrics.
# Sanity-print the LR/HR chip geometry and the implied upscaling factor.
lr_chip_hw = lowres_chips.shape[1:3]
hr_chip_hw = highres_chips.shape[1:3]
chip_shape_scale = (hr_chip_hw[0] / lr_chip_hw[0], hr_chip_hw[1] / lr_chip_hw[1])
print(
    f"Chip geometry: LR chip={lowres_chips.shape[1:3]}, HR chip={highres_chips.shape[1:3]}, "
    f"scale(h,w)=({chip_shape_scale[0]:.2f}, {chip_shape_scale[1]:.2f})"
)

# Per-chip metrics for the model vs the bilinear baseline on the same chips.
chip_summary, chip_per_sample = results.evaluate_chip_arrays_vs_bilinear(
    lowres_chips=lowres_chips,
    highres_chips=highres_chips,
    preds_chips=preds_chips,
    split_name="inference_chips",
    max_depth=MAX_DEPTH,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
)

print("Per-chip summary:")
print(json.dumps(chip_summary, indent=2, sort_keys=True))

# Scatter of chip metric vs mean chip depth, model against baseline.
fig_scatter, _ = results.plot_metric_scatter_vs_mean_depth(
    chip_per_sample,
    model_label=MODEL_SERIES_LABEL,
    baseline_label="Bilinear",
    split_summary=chip_summary,
    model_name=MODEL_NAME,
    loss_plot_label=LOSS_LABEL,
)
plt.show()
plt.close(fig_scatter)
Chip geometry: LR chip=(32, 32), HR chip=(512, 512), scale(h,w)=(16.00, 16.00)
Per-chip summary:
{
  "baseline": {
    "CSI": 0.2644336521625519,
    "CSI_001cm": 0.2644336521625519,
    "CSI_050cm": 0.0,
    "CSI_100cm": 0.0,
    "MAE": 0.016119904816150665,
    "PSNR": 27.737504959106445,
    "RMSE": 0.051067858934402466,
    "RMSE_wet": null,
    "RMSE_wet_001cm": 0.12332958728075027,
    "RMSE_wet_050cm": null,
    "RMSE_wet_100cm": null,
    "SSIM": 0.8552564978599548
  },
  "best_epoch": {
    "CSI": 0.27459511160850525,
    "CSI_001cm": 0.27459511160850525,
    "CSI_050cm": 0.0,
    "CSI_100cm": 0.0,
    "MAE": 0.006825117394328117,
    "PSNR": 34.019412994384766,
    "RMSE": 0.023256177082657814,
    "RMSE_wet": null,
    "RMSE_wet_001cm": 0.07845866680145264,
    "RMSE_wet_050cm": null,
    "RMSE_wet_100cm": null,
    "SSIM": 0.8954101204872131
  }
}
No description has been provided for this image
In [24]:
# Plot chip-level statistics and best/worst examples.
fig_chip_scatter, _ = results.plot_chip_stat_scatter(
    chip_per_sample,
    model_label=MODEL_SERIES_LABEL,
    baseline_label="Bilinear",
    model_name=MODEL_NAME,
)
plt.show()
# Release the figure after display — the other plotting cells in this
# notebook close their figures, and leaving it open accumulates matplotlib
# state across re-runs.
plt.close(fig_chip_scatter)
 
No description has been provided for this image

PLOT: chip examples¶

In [25]:
# Render example panels for the best and worst chips (3 each, ranked by SR MAE).
_ = results.plot_best_worst_chip_examples(
    lowres_chips=lowres_chips,
    highres_chips=highres_chips,
    preds_chips=preds_chips,
    chip_ids=chip_coords,
    max_depth=MAX_DEPTH,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
    n_show=3,
    cmap="cividis",
)
Scanned test chips: 36
Retained candidate chips in memory: 6

Worst chips (highest SR MAE):
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
Best chips (lowest SR MAE):
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

9) Mosaic-level Inference, Bilinear Comparison, and Export¶

In [29]:
# Build bilinear baseline on raw HR grid and compare mosaic-level metrics.
# Report the scale factor both from array shapes and from pixel resolutions,
# since the LR and HR rasters may cover slightly different extents.
baseline_shape_scale = (hr_valid.shape[0] / lr_norm.shape[0], hr_valid.shape[1] / lr_norm.shape[1])
baseline_res_scale = (lr_res_raw[0] / hr_res[0], lr_res_raw[1] / hr_res[1])
print("Full-scene bilinear baseline geometry (raw-grid evaluation):")
print("  baseline method (fixed): bilinear")
print(f"  LR raw shape/res: {lr_norm.shape} / ({lr_res_raw[0]:.6g}, {lr_res_raw[1]:.6g})")
print(f"  HR raw shape/res: {hr_valid.shape} / ({hr_res[0]:.6g}, {hr_res[1]:.6g})")
print(f"  scale from shape (h,w): ({baseline_shape_scale[0]:.2f}, {baseline_shape_scale[1]:.2f})")
print(f"  scale from resolution (x,y): ({baseline_res_scale[0]:.2f}, {baseline_res_scale[1]:.2f})")


def _as_batch(arr):
    """Wrap a 2-D array as a (1, H, W, 1) float32 tensor for the metric helpers."""
    return tf.convert_to_tensor(arr[None, ..., None], dtype=tf.float32)


# Antialiased bilinear upsampling of raw LR to the HR grid, clipped to [0, 1].
upsampled = tf.image.resize(
    _as_batch(lr_norm),
    size=raw_hr_shape,
    method="bilinear",
    antialias=True,
)
baseline = np.clip(upsampled[0, ..., 0].numpy().astype(np.float32, copy=False), 0.0, 1.0)
print(f"  bilinear output shape: {baseline.shape}")
print(f"  bilinear output resolution: ({hr_res[0]:.6g}, {hr_res[1]:.6g})")

# Compute normalized-space metrics for SR and bilinear on the full mosaic.
hr_full = _as_batch(hr_valid)
sr_full = _as_batch(sr)
bl_full = _as_batch(baseline)

sr_metric_tensors = results.compute_per_sample_metrics(hr_full, sr_full)
bl_metric_tensors = results.compute_per_sample_metrics(hr_full, bl_full)

metrics_sr = results.reduce_metric_buffers({k: [v] for k, v in sr_metric_tensors.items()})
metrics_bilinear = results.reduce_metric_buffers({k: [v] for k, v in bl_metric_tensors.items()})

# Tabulate model vs baseline with the model-minus-baseline delta column.
df = (
    pd.DataFrame({"ResUNet": metrics_sr, "Bilinear": metrics_bilinear})
    .loc[list(results.METRIC_KEYS), ["ResUNet", "Bilinear"]]
    .assign(delta=lambda t: t["ResUNet"] - t["Bilinear"])
)
print("Mosaic metric summary computed on raw HR grid.")
df.round(4)
Full-scene bilinear baseline geometry (raw-grid evaluation):
  baseline method (fixed): bilinear
  LR raw shape/res: (203, 209) / (30, 30)
  HR raw shape/res: (3248, 3344) / (3, 3)
  scale from shape (h,w): (16.00, 16.00)
  scale from resolution (x,y): (10.00, 10.00)
  bilinear output shape: (3248, 3344)
  bilinear output resolution: (3, 3)
Mosaic metric summary computed on raw HR grid.
Out[29]:
ResUNet Bilinear delta
MAE 0.0059 0.0145 -0.0086
PSNR 32.0955 24.8302 7.2653
SSIM 0.9067 0.8669 0.0398
RMSE 0.0248 0.0573 -0.0325
RMSE_wet_001cm 0.0848 0.1508 -0.0661
RMSE_wet_050cm NaN NaN NaN
RMSE_wet_100cm NaN NaN NaN
CSI_001cm 0.3097 0.3032 0.0065
CSI_050cm 0.0000 0.0000 0.0000
CSI_100cm 0.0000 0.0000 0.0000
In [30]:
# Plot final mosaic-level comparison diagnostics in depth units.
full_lr = tf.convert_to_tensor(lr_norm[..., None], dtype=tf.float32)
full_hr = tf.convert_to_tensor(hr_valid[..., None], dtype=tf.float32)
full_sr = tf.convert_to_tensor(sr[..., None], dtype=tf.float32)

print("Full-scene inference diagnostics (raw-grid evaluation)")
fig, final_metrics = results.plot_chip_comparison(
    highres=full_hr,
    lowres=full_lr,
    preds=full_sr,
    lowres_resolution=lr_res_raw,
    highres_resolution=hr_res,
    max_depth=MAX_DEPTH,
    dry_depth_thresh_m=DRY_DEPTH_THRESH_M,
    cmap="cividis",
)
plt.show()
plt.close(fig)

# Headline numbers printed alongside the figure.
tile_label = "full-scene"
print(f"PSNR between LR and HR image {tile_label}: {final_metrics['lr_psnr']:.4f}")
print(f"SSIM between LR and HR image {tile_label}: {final_metrics['lr_ssim']:.4f}")
print(f"PSNR between HR and SR image {tile_label}: {final_metrics['sr_psnr']:.4f}")
print(f"SSIM between HR and SR image {tile_label}: {final_metrics['sr_ssim']:.4f}")
print(f"MAE between HR and SR image {tile_label}: {final_metrics['sr_mae_m']:.6f} m")

# Display the full metrics dict as this cell's rich output.
final_metrics
Full-scene inference diagnostics (raw-grid evaluation)
No description has been provided for this image
PSNR between LR and HR image full-scene: 26.6534
SSIM between LR and HR image full-scene: 0.9146
PSNR between HR and SR image full-scene: 39.0265
SSIM between HR and SR image full-scene: 0.9479
MAE between HR and SR image full-scene: 0.012357 m
Out[30]:
{'lr_psnr': 26.653362274169922,
 'lr_ssim': 0.9145838022232056,
 'lr_mae_m': 0.04336462914943695,
 'lr_rmse_m': 0.2324351817369461,
 'lr_rmse_wet_m': 0.6266283392906189,
 'lr_bias_m': 0.03487890958786011,
 'lr_wet_pixel_count': 882346,
 'lr_dry_pixel_count': 9978966,
 'sr_psnr': 39.0264778137207,
 'sr_ssim': 0.9478932619094849,
 'sr_mae_m': 0.012356706894934177,
 'sr_rmse_m': 0.05593014508485794,
 'sr_rmse_wet_m': 0.17960862815380096,
 'sr_bias_m': -0.004448903724551201,
 'sr_wet_pixel_count': 882346,
 'sr_dry_pixel_count': 9978966,
 'hr_wet_pixel_count': 882346,
 'hr_dry_pixel_count': 9978966,
 'hr_mean_depth_m': 0.012464739382266998,
 'hr_max_depth_m': 1.0,
 'hr_min_depth_m': 0.0,
 'bl_psnr': 28.806163787841797,
 'bl_ssim': 0.9093832969665527,
 'bl_mae_m': 0.036564625799655914,
 'bl_rmse_m': 0.18141023814678192,
 'bl_rmse_wet_m': 0.5058002471923828,
 'bl_bias_m': 0.03031882829964161,
 'bl_wet_pixel_count': 882346,
 'bl_dry_pixel_count': 9978966}